Use a single device/dtype #86

Merged

SCiarella merged 10 commits into main from fixdevicedtype on Feb 11, 2026

Conversation

@SCiarella
Collaborator

Closes #84

@SCiarella SCiarella marked this pull request as ready for review February 3, 2026 12:48
Collaborator

@SarahAlidoost left a comment


@SCiarella thanks for the fixes! 👍 Please see my suggestions, and also issue #87. Sorry if it requires more changes than expected. We can also discuss the issue offline if needed.

@SCiarella
Collaborator Author

This PR now also closes #87.

Following @ronvree's suggestion, ComputeConfig is now initialized to PyTorch's current default values, making it easier for the user to avoid type/device conflicts.
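A minimal sketch of what initializing from PyTorch's current defaults might look like. The class internals and names below are assumptions for illustration, not the PR's actual implementation; the torch import is guarded so the sketch also runs without PyTorch installed.

```python
# Hypothetical sketch: ComputeConfig picking up PyTorch's current
# defaults at import time (names/attributes are assumed, not the
# PR's actual code).
try:
    import torch
    _DEFAULT_DTYPE = str(torch.get_default_dtype())  # e.g. "torch.float32"
except ImportError:
    # Fallback so the sketch runs without torch; float32 is
    # PyTorch's out-of-the-box default dtype.
    _DEFAULT_DTYPE = "torch.float32"


class ComputeConfig:
    # Defaults captured once, matching what PyTorch would use anyway,
    # so a user who never touches ComputeConfig sees no conflicts.
    dtype: str = _DEFAULT_DTYPE
    device: str = "cpu"  # PyTorch's default device

    @classmethod
    def set_device(cls, device: str) -> None:
        cls.device = device
```

With this shape, a user only needs to call `ComputeConfig.set_device(...)` when deviating from the defaults.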

@ronvree

ronvree commented Feb 5, 2026

> This PR now also closes #87.
>
> Following @ronvree's suggestion, ComputeConfig is now initialized to PyTorch's current default values, making it easier for the user to avoid type/device conflicts.

Perfect, thanks a lot! Sorry to be so difficult about this, but one concern from our previous discussion has not yet been addressed: ComputeConfig stores the dtype/device as global class properties, which, if I understand correctly, prevents a user from having two instances on different devices. What are your thoughts on tying device/dtype to model instances? Thanks again for picking this up!
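The concern can be illustrated with a minimal sketch (class names and attributes here are assumed for illustration, not taken from the repository): if every model reads the device from the class attribute at use time, all instances track the latest global setting and cannot differ.

```python
# Hypothetical sketch of the concern: with a purely global config,
# every model reads the *current* class attribute, so two instances
# cannot live on different devices at the same time.
class ComputeConfig:
    device: str = "cpu"


class GlobalModel:
    @property
    def device(self) -> str:
        # Always resolves to the latest global value.
        return ComputeConfig.device


ComputeConfig.device = "cuda:0"
m1 = GlobalModel()
ComputeConfig.device = "cuda:1"
m2 = GlobalModel()
# Both now report "cuda:1" -- m1's intended device was lost.
```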

@SCiarella
Collaborator Author

Hi @ronvree, thanks for the suggestion!

Yes, it makes sense to allow different instances to have different device/dtype settings.
I have implemented this by storing the ComputeConfig settings in each instance during initialization, so you can now do something like:

```python
# Create model on GPU 0
ComputeConfig.set_device("cuda:0")
model1 = Model()  # captures cuda:0

# Create model on GPU 1
ComputeConfig.set_device("cuda:1")
model2 = Model()  # captures cuda:1
```

while retaining all the advantages of a global container that synchronizes the different models and submodels.

I have also added test_config.py::TestComputeConfig::test_models_capture_config_at_initialization to test this feature and confirm that it works as expected.
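The capture-at-initialization behavior that `test_models_capture_config_at_initialization` presumably exercises can be sketched as follows; the ComputeConfig and Model internals below are assumptions for illustration, not the PR's actual code:

```python
# Hypothetical sketch of the capture-at-initialization pattern
# (class internals are assumed, not taken from the repository).
class ComputeConfig:
    device: str = "cpu"
    dtype: str = "float32"

    @classmethod
    def set_device(cls, device: str) -> None:
        cls.device = device


class Model:
    def __init__(self) -> None:
        # Snapshot the global settings at construction time; later
        # changes to ComputeConfig do not affect this instance.
        self.device = ComputeConfig.device
        self.dtype = ComputeConfig.dtype


ComputeConfig.set_device("cuda:0")
model1 = Model()  # captures cuda:0
ComputeConfig.set_device("cuda:1")
model2 = Model()  # captures cuda:1
```

The key design point is that the snapshot happens in `__init__`, so the global container still coordinates defaults while each instance keeps its own device/dtype.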

@ronvree

ronvree commented Feb 6, 2026

@SCiarella Perfect thanks a lot!!

Collaborator

@SarahAlidoost left a comment


@SCiarella thanks! Looks very nice 💯 Just a small comment; feel free to merge.

SCiarella and others added 3 commits February 10, 2026 14:13
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>

@SCiarella SCiarella merged commit 2eddd79 into main Feb 11, 2026
11 checks passed
@SCiarella SCiarella deleted the fixdevicedtype branch February 11, 2026 14:22

Development

Successfully merging this pull request may close these issues.

[Task]: Fix device and dtype

3 participants